Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Neuron][Kernel] NKI-based flash-attention kernel with paged KV cache #11277

Merged
merged 2 commits into from
Jan 28, 2025

Conversation

liangfu
Copy link
Contributor

@liangfu liangfu commented Dec 18, 2024

Summary

FIX #11152

This PR introduce a NKI-based kernel that brings the support for chunked-prefill with flash-attention.

Co-authored-by: Jiangfei Duan [email protected]

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@liangfu liangfu force-pushed the nki-flash-attn branch 3 times, most recently from 188fc99 to b63906b Compare December 19, 2024 06:50
@mergify mergify bot added the ci/build label Jan 6, 2025
Copy link

mergify bot commented Jan 7, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @liangfu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 7, 2025
@liangfu liangfu marked this pull request as ready for review January 7, 2025 06:39
@mergify mergify bot removed the needs-rebase label Jan 7, 2025
@liangfu liangfu changed the title [Draft][Neuron][Kernel] NKI-based flash-attention kernel with paged KV cache [Neuron][Kernel] NKI-based flash-attention kernel with paged KV cache Jan 7, 2025
Co-authored-by: Jiangfei Duan <[email protected]>
Signed-off-by: Liangfu Chen <[email protected]>
@simon-mo simon-mo added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 9, 2025
@simon-mo
Copy link
Collaborator

simon-mo commented Jan 9, 2025

@robertgshaw2-neuralmagic @WoosukKwon PTAL

@robertgshaw2-redhat robertgshaw2-redhat self-assigned this Jan 11, 2025
return o


def flash_attn_varlen_nkifunc(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we could align this to the API of the unified flash attention funciton for V1?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.g. in v1/attention/backends/flash_attention

flash_attn_varlen_func(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=attn_metadata.query_start_loc,
                max_seqlen_q=attn_metadata.max_query_len,
                cu_seqlens_k=attn_metadata.seq_start_loc,
                max_seqlen_k=attn_metadata.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=attn_metadata.block_table,
                softcap=self.logits_soft_cap,
            )

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aligning to this interface will reduce special cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great question!

There are three challenges that are blocking the alignment of the interface:

  1. calling reshape_and_cache_flash before flash_attn_varlen_func can be inefficient read-after-write, due to synchronous process on a heterogeneous architecture. On neuron stack, it's better to merge cached tokens with new tokens in an asynchronous fashion. Therefore, we pass both cached KV and active KV to the flash-attention function call, so that writing to HBM won't block compuation.
  2. Since slicing num_actual_tokens would not reduce computation or bandwidth utilization, I think it's better to slice the logits in the last attention layer.
  3. We find it more efficient to compuate attention mask before each of the layer. Therefore, instead of passing sequence lengths and query lengths, we pass the pre-computed attention mask directly to the flash-attention kernel.

I'm actively looking for ideas that can help close the gap between the interfaces, without degrading performance of the kernel.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@robertgshaw2-redhat do you think the interface misalignment can be a blocking issue for merging?

@robertgshaw2-redhat
Copy link
Collaborator

Nice work!

@liangfu
Copy link
Contributor Author

liangfu commented Jan 16, 2025

@simon-mo @WoosukKwon do you have any other concerns / comments ?

Copy link
Collaborator

@simon-mo simon-mo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very isolated change. I'm okay with merging this

@simon-mo simon-mo merged commit ddee88d into vllm-project:main Jan 28, 2025
55 checks passed
None), "continuous_batching_mask does not support logit_bias!"

# mask are used to only apply computation to the lower half of the matrix,
# which reduce the arthimetic intensity by half
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is a typo arthimetic breaks the CI.

@mgoin
Copy link
Member

mgoin commented Jan 28, 2025

PR to fix the failing precommit #12497

tjtanaa pushed a commit to EmbeddedLLM/vllm that referenced this pull request Jan 28, 2025
rasmith pushed a commit to rasmith/vllm that referenced this pull request Jan 30, 2025
Isotr0py pushed a commit to Isotr0py/vllm that referenced this pull request Feb 2, 2025
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Feb 7, 2025
ShangmingCai pushed a commit to ShangmingCai/vllm that referenced this pull request Feb 10, 2025
GWS0428 pushed a commit to GWS0428/VARserve that referenced this pull request Feb 12, 2025
panf2333 pushed a commit to yottalabsai/vllm that referenced this pull request Feb 18, 2025
kerthcet pushed a commit to kerthcet/vllm that referenced this pull request Feb 21, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[RFC][Exploratory]: vLLM Neuron Backend with V1 Architecture
6 participants